# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from functools import partial
from types import MethodType
from typing import Any, Dict, Iterator, List

import datasets
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from packaging import version
from torch.distributed.device_mesh import init_device_mesh
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Sampler

from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model, to_device
from swift.utils import get_current_device, get_device, get_dist_setting, seed_worker
from .base import CommonSequenceParallel
from .utils import (
    ChunkedCrossEntropyLoss,
    GatherLoss,
    SequenceParallelDispatcher,
    SequenceParallelSampler,
    _get_per_token_logps_grpo,
    _get_train_sampler_grpo,
    _prepare_inputs,
    _prepare_inputs_grpo,
    get_common_dataloader,
    get_per_token_logps,
    loss_scale_sp_func,
    old_policy_grpo,
    setup_compute_acc,
    split_by_mini_batches_grpo,
)

RING_ATTN_GROUP = None


def set_ring_attn_group(group):
    global RING_ATTN_GROUP
    RING_ATTN_GROUP = group


def get_ring_attn_group():
    return RING_ATTN_GROUP


def reset_ring_attn_position_ids(start, end, packed_seq_lens):
    """
    Calculate position ids for packed_seq_ids[start:end].
    For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8,
    the position ids will be [2, 0, 1, 0, 1, 2].

    Args:
        start: the start position
        end: the end position
        packed_seq_lens: the sequence lengths of packed sequences
    """
    position_ids = torch.zeros(
        (1, end - start), dtype=torch.long, device=torch.cuda.current_device()
    )
    offset = 0
    for seqlen in packed_seq_lens:
        seq_start = max(offset, start)
        seq_end = min(offset + seqlen, end)
        if seq_start < seq_end:
            position_ids[0, seq_start - start : seq_end - start] = torch.arange(
                seq_start - offset, seq_end - offset
            )

        offset += seqlen
        if offset >= end:
            break
    return position_ids
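
# Usage sketch (assumption; this helper is not called in this module): with ring attention each
# rank owns a contiguous chunk of the packed batch, so rank r with chunk size c would rebuild its
# local position ids via reset_ring_attn_position_ids(r * c, (r + 1) * c, packed_seq_lens),
# e.g. reset_ring_attn_position_ids(2, 8, [3, 2, 4, 1]) -> tensor([[2, 0, 1, 0, 1, 2]]).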


def update_ring_attn_params(packed_seq_lens, total_seq_len):
    """
    Calculate the cu_seqlens for the current forward pass and pass the value to
    the substituted ring_flash_attn.
    """
    assert RING_ATTN_GROUP is not None
    cu_seqlens = torch.cumsum(
        torch.tensor(
            packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
        ),
        dim=-1,
        dtype=torch.int32,
    )
    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)

    from ring_flash_attn import update_ring_flash_attn_params

    update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP)
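
# Example of the resulting layout (sketch): packed_seq_lens=[3, 2] with total_seq_len=8
# (the packed batch is padded so it splits evenly across the sp ranks) yields
# cu_seqlens=[0, 3, 5, 8], i.e. the padded tail is treated as one extra segment.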


def infer_packed_seq_lens_from_position_ids(position_ids):
    """
    Infer packed sequence lengths from position_ids.

    For packed sequences, position_ids are flattened. For example:
    - Original sequence lengths: [2, 3]
    - Position_ids: [0, 1, 0, 1, 2]

    Args:
        position_ids: torch.Tensor of shape (batch_size, seq_len) or (seq_len,)

    Returns:
        List of sequence lengths for each packed sequence
    """
    if position_ids.dim() == 2:
        # Handle batch dimension - assume batch_size=1 for packed sequences
        position_ids = position_ids.squeeze(0)

    position_ids = position_ids.cpu().tolist()
    packed_seq_lens = []
    current_seq_len = 0

    for i, pos_id in enumerate(position_ids):
        # A position id that resets to 0 (except for the first token) marks the start of a new sequence
        if i > 0 and pos_id == 0:
            packed_seq_lens.append(current_seq_len)
            current_seq_len = 0
        current_seq_len += 1

    # Add the last sequence length
    if current_seq_len > 0:
        packed_seq_lens.append(current_seq_len)

    return packed_seq_lens
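
# A concrete instance of the docstring example above:
#   infer_packed_seq_lens_from_position_ids(torch.tensor([[0, 1, 0, 1, 2]])) -> [2, 3]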


class RingAttention(CommonSequenceParallel):
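    """Sequence parallel implementation based on ring attention (ring-flash-attn).

    A usage sketch (assumption, based only on the methods defined in this class):

        sp = RingAttention()
        sp.init_sequence_parallel(size)       # build the device mesh and patch flash attention
        sp.prepare_model(model, tokenizer)    # register the pre-forward hook that splits inputs
        sp.prepare_trainer(trainer)           # patch the trainer's input preparation and loss
    """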

    def __init__(self):
        """Initialize RingAttention sequence parallel implementation."""
        super().__init__()

    def init_sequence_parallel(self, size):
        """Initialize ring attention sequence parallel with given size.

        Args:
            size: The sequence parallel world size
        """
        if self._inited:
            return
        self._inited = True

        self.sp_world_size = size
        self._init_device_mesh()

        # Set global ring attention group using the sequence dimension
        ring_attn_group = self.device_mesh["sequence"].get_group()
        set_ring_attn_group(ring_attn_group)

        # Import and set up ring flash attention.
        # RING_HEAD_STRIDE is forwarded to substitute_hf_flash_attn as its head-stride argument,
        # which controls how many key/value heads are communicated at a time in ring_flash_attn.
        ring_head_stride = int(os.environ.get("RING_HEAD_STRIDE", 2))
        try:
            from ring_flash_attn import substitute_hf_flash_attn

            # Substitute HuggingFace flash attention with ring attention
            substitute_hf_flash_attn(ring_attn_group, ring_head_stride)
        except ImportError:
            raise ImportError(
                "ring-flash-attn is required for RingAttention. "
                "Please install it with: pip install ring-flash-attn"
            )

    def prepare_model(self, model, tokenizer):
        """Prepare the model for ring attention sequence parallel.

        Args:
            model: The model to prepare
            tokenizer: The tokenizer to use
        """

        def pre_forward_hook(_self, args, kwargs):
            """Hook to process inputs before forward pass for ring attention."""
            input_ids = kwargs.get("input_ids", None)
            inputs_embeds = kwargs.get("inputs_embeds", None)
            position_ids = kwargs.get("position_ids", None)
            attention_mask = kwargs.get("attention_mask", None)

            # packed_seq_lens is calculated in data_collator when padding_free/packing mode is enabled
            packed_seq_lens = kwargs.get("packed_seq_lens", None)

            if packed_seq_lens is None:
                packed_seq_lens = infer_packed_seq_lens_from_position_ids(position_ids)

            # Get embed_tokens for padding
            if hasattr(_self, "language_model"):
                embed_tokens = getattr(_self.language_model, "embed_tokens", None)
            else:
                embed_tokens = getattr(_self, "embed_tokens", None)

            input_ids, inputs_embeds, _, position_ids, attention_mask, _ = (
                self.pad_and_split_inputs(
                    input_ids,
                    inputs_embeds,
                    None,
                    position_ids,
                    attention_mask,
                    None,
                    embed_tokens=embed_tokens,
                )
            )

            kwargs["input_ids"] = input_ids
            kwargs["inputs_embeds"] = inputs_embeds
            kwargs["position_ids"] = position_ids
            kwargs["attention_mask"] = attention_mask

            # position_ids now holds only this rank's chunk, so the full padded length is the
            # local length times the sequence parallel world size
            total_seq_len = position_ids.numel() * self.sp_world_size

            update_ring_attn_params(packed_seq_lens, total_seq_len)

            return args, kwargs

        # Get the base model to register the hook
        llm_model = get_llm_model(model)

        if hasattr(llm_model, "thinker"):
            base_model = llm_model.thinker.model
        else:
            base_model = llm_model.model

        # Register the pre-forward hook (with_kwargs=True so the hook can read and rewrite kwargs)
        base_model.register_forward_pre_hook(pre_forward_hook, with_kwargs=True)

        # Store model dtype and tokenizer
        self.model_dtype = next(model.parameters()).dtype
        self.tokenizer = tokenizer

    def get_dataloader(self, trainer, dataset, batch_size, skip_batches: int = 0):
        return get_common_dataloader(
            self,
            trainer,
            dataset,
            batch_size,
            SequenceParallelSampler,
            SequenceParallelDispatcher,
            skip_batches=skip_batches,
        )

    def prepare_trainer(self, trainer):
        """Prepare trainer for ring attention sequence parallel.

        Args:
            trainer: The trainer to prepare
        """
        if trainer.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        trainer.ring_attention = self

        if trainer.__class__.__name__ == "Seq2SeqTrainer":
            trainer._origin_prepare_inputs = trainer._prepare_inputs
            trainer._prepare_inputs = MethodType(
                partial(_prepare_inputs, sp_instance=self), trainer
            )
            trainer.compute_loss_func = partial(loss_scale_sp_func, sp_instance=self)

        elif trainer.__class__.__name__ == "DPOTrainer":
            trainer._origin_prepare_inputs = trainer._prepare_inputs
            trainer._prepare_inputs = MethodType(
                partial(_prepare_inputs, sp_instance=self), trainer
            )
            trainer.get_per_token_logps = partial(get_per_token_logps, sp_instance=self)

        elif trainer.__class__.__name__ == "GRPOTrainer":
            try:
                import trl

                assert version.parse(trl.__version__) >= version.parse("0.18.0")
            except (ImportError, AssertionError):
                raise ImportError(
                    "trl>=0.18.0 is required for GRPOTrainer with ring attention. "
                    'Please install it with: pip install "trl>=0.18.0"'
                )

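            # The ranks in a sequence parallel group work on the same samples, so the number of
            # distinct samples per optimizer step drops by sp_world_size; scaling gradient
            # accumulation presumably compensates for this.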
            trainer.args.gradient_accumulation_steps = (
                trainer.args.gradient_accumulation_steps * self.sp_world_size
            )
            trainer.old_policy = MethodType(
                partial(old_policy_grpo, sp_instance=self), trainer
            )
            trainer._get_train_sampler = MethodType(
                partial(_get_train_sampler_grpo, sp_instance=self), trainer
            )
            trainer._prepare_inputs = MethodType(
                partial(_prepare_inputs_grpo, sp_instance=self), trainer
            )
            trainer._get_per_token_logps = MethodType(
                partial(_get_per_token_logps_grpo, sp_instance=self), trainer
            )
            trainer.split_by_mini_batches = MethodType(
                partial(split_by_mini_batches_grpo, sp_instance=self), trainer
            )

        setup_compute_acc(self)
